import json
import torch
from sentence_transformers import util
from sentence_transformers import SentenceTransformer
import sys

# Precomputed embedding lookup tables, loaded once when the module is imported.
# Paths are relative to the current working directory.
#
# The encoding is pinned to UTF-8 (the JSON interchange encoding) so decoding
# does not depend on the platform's locale default.

# Embeddings looked up by example-query key (see get_oai_sim).
with open("PQ_emb_saves/exmQuery_oaiemb_dict.json", "r", encoding="utf-8") as file:
    exm_query_oaiemb_dict = json.load(file)
# Embeddings looked up by test-query key.
with open("PQ_emb_saves/testQuery_oaiemb_dict.json", "r", encoding="utf-8") as file:
    test_query_oaiemb_dict = json.load(file)
# Embeddings looked up by document/entity representation key.
with open("PQ_emb_saves/docentrep_oaiemb_mapping.json", "r", encoding="utf-8") as file:
    docentrep_oaiemb_dict = json.load(file)
# Embeddings used by get_sbertEmbSim (presumably SBERT vectors — confirm
# against the script that produced the file).
with open("PQ_emb_saves/sbert_emb_mapping.json", "r", encoding="utf-8") as file:
    sbert_emb_mapping = json.load(file)

def get_oai_sim(test_query, exm_query_list) -> torch.Tensor:
    """Score one test query against a list of example queries.

    The embedding of ``test_query`` is read from ``test_query_oaiemb_dict``
    and the embedding of every entry of ``exm_query_list`` from
    ``exm_query_oaiemb_dict``; the two are compared with
    ``util.pytorch_cos_sim``.

    Args:
        test_query: Key into ``test_query_oaiemb_dict``.
        exm_query_list: Keys into ``exm_query_oaiemb_dict``.

    Returns:
        The first row of the cosine-similarity matrix. Assuming each stored
        embedding is a flat list of floats, that is one score per example
        query, in the order of ``exm_query_list``.

    Raises:
        KeyError: If a key is missing from its mapping (the mappings are
            presumed to be JSON objects keyed by query text).
    """
    query_vec = torch.tensor(test_query_oaiemb_dict[test_query])
    example_rows = []
    for example in exm_query_list:
        example_rows.append(exm_query_oaiemb_dict[example])
    example_matrix = torch.tensor(example_rows)
    # The query side contributes a single row to the matrix; keep only it.
    return util.pytorch_cos_sim(query_vec, example_matrix)[0]

def get_oaisim_for_sbertTraining(exm_query, docent_rep_list) -> torch.Tensor:
    """Score one example query against a list of document/entity representations.

    The embedding of ``exm_query`` is read from ``exm_query_oaiemb_dict``
    and the embedding of every entry of ``docent_rep_list`` from
    ``docentrep_oaiemb_dict``; the two are compared with
    ``util.pytorch_cos_sim``.

    Args:
        exm_query: Key into ``exm_query_oaiemb_dict``.
        docent_rep_list: Keys into ``docentrep_oaiemb_dict``.

    Returns:
        The first row of the cosine-similarity matrix. Assuming each stored
        embedding is a flat list of floats, that is one score per
        representation, in the order of ``docent_rep_list``.

    Raises:
        KeyError: If a key is missing from its mapping (the mappings are
            presumed to be JSON objects keyed by text).
    """
    query_vec = torch.tensor(exm_query_oaiemb_dict[exm_query])
    rep_rows = []
    for rep in docent_rep_list:
        rep_rows.append(docentrep_oaiemb_dict[rep])
    rep_matrix = torch.tensor(rep_rows)
    # The query side contributes a single row to the matrix; keep only it.
    return util.pytorch_cos_sim(query_vec, rep_matrix)[0]

def get_oaisim_for_sbertValidation(test_query, docent_rep_list) -> torch.Tensor:
    """Score one test query against a list of document/entity representations.

    The embedding of ``test_query`` is read from ``test_query_oaiemb_dict``
    and the embedding of every entry of ``docent_rep_list`` from
    ``docentrep_oaiemb_dict``; the two are compared with
    ``util.pytorch_cos_sim``.

    Args:
        test_query: Key into ``test_query_oaiemb_dict``.
        docent_rep_list: Keys into ``docentrep_oaiemb_dict``.

    Returns:
        The first row of the cosine-similarity matrix. Assuming each stored
        embedding is a flat list of floats, that is one score per
        representation, in the order of ``docent_rep_list``.

    Raises:
        KeyError: If a key is missing from its mapping (the mappings are
            presumed to be JSON objects keyed by text).
    """
    test_vec = torch.tensor(test_query_oaiemb_dict[test_query])
    rep_rows = []
    for rep in docent_rep_list:
        rep_rows.append(docentrep_oaiemb_dict[rep])
    rep_matrix = torch.tensor(rep_rows)
    # The query side contributes a single row to the matrix; keep only it.
    return util.pytorch_cos_sim(test_vec, rep_matrix)[0]

def get_sbertEmbSim(query, query_list) -> torch.Tensor:
    """Score one query against a list of queries using ``sbert_emb_mapping``.

    Both ``query`` and every entry of ``query_list`` are looked up in the
    same mapping, ``sbert_emb_mapping``, and compared with
    ``util.pytorch_cos_sim``.

    Args:
        query: Key into ``sbert_emb_mapping``.
        query_list: Keys into ``sbert_emb_mapping``.

    Returns:
        The first row of the cosine-similarity matrix. Assuming each stored
        embedding is a flat list of floats, that is one score per entry of
        ``query_list``, in list order.

    Raises:
        KeyError: If a key is missing from ``sbert_emb_mapping`` (presumed
            to be a JSON object keyed by query text).
    """
    anchor_vec = torch.tensor(sbert_emb_mapping[query])
    candidate_rows = []
    for candidate in query_list:
        candidate_rows.append(sbert_emb_mapping[candidate])
    candidate_matrix = torch.tensor(candidate_rows)
    # The anchor side contributes a single row to the matrix; keep only it.
    return util.pytorch_cos_sim(anchor_vec, candidate_matrix)[0]

def get_docent_from_code(code: str, doc_dict: dict) -> list:
    """Return the keys of ``doc_dict`` that occur as substrings of ``code``.

    Args:
        code: Text to search in.
        doc_dict: Mapping whose keys are the names to look for; only the
            keys are used.

    Returns:
        The matching keys, in the iteration order of ``doc_dict``.
    """
    found = []
    for name in doc_dict:
        # Plain substring test: a name also matches inside a longer identifier.
        if name in code:
            found.append(name)
    return found

def calc_recall(ret_list, good_list) -> float:
    """Return the fraction of distinct relevant items that were retrieved.

    Both inputs are treated as sets, so duplicates on either side do not
    affect the score.

    Args:
        ret_list: Retrieved items (hashable).
        good_list: Relevant (ground-truth) items (hashable).

    Returns:
        ``|ret ∩ good| / |good|`` over the distinct items, or ``0.0`` when
        ``good_list`` is empty (recall is undefined there; 0.0 avoids a
        ``ZeroDivisionError``).
    """
    relevant = set(good_list)
    if not relevant:
        return 0.0
    hits = relevant & set(ret_list)
    # Divide by the distinct count so the denominator matches the deduplicated
    # numerator.
    return len(hits) / len(relevant)

def calc_precision(ret_list, good_list) -> float:
    """Return the fraction of distinct retrieved items that are relevant.

    Both inputs are treated as sets, so duplicates on either side do not
    affect the score.

    Args:
        ret_list: Retrieved items (hashable).
        good_list: Relevant (ground-truth) items (hashable).

    Returns:
        ``|ret ∩ good| / |ret|`` over the distinct items, or ``0.0`` when
        ``ret_list`` is empty (precision is undefined there; 0.0 avoids a
        ``ZeroDivisionError``).
    """
    retrieved = set(ret_list)
    if not retrieved:
        return 0.0
    hits = retrieved & set(good_list)
    # Divide by the distinct count so the denominator matches the deduplicated
    # numerator.
    return len(hits) / len(retrieved)

def calc_jaccard(ret_list, good_list) -> float:
    """Return the Jaccard index of the retrieved and relevant item sets.

    Args:
        ret_list: Retrieved items (hashable).
        good_list: Relevant (ground-truth) items (hashable).

    Returns:
        ``|ret ∩ good| / |ret ∪ good|`` over the distinct items. When both
        inputs are empty the ratio is undefined; ``0.0`` is returned instead
        of raising ``ZeroDivisionError``.
    """
    ret_set = set(ret_list)
    good_set = set(good_list)
    union = ret_set | good_set
    if not union:
        # NOTE(review): 0.0 is chosen for "nothing retrieved, nothing
        # relevant"; switch to 1.0 if identical empty sets should count as a
        # perfect match.
        return 0.0
    return len(ret_set & good_set) / len(union)


